package flow

import (
	"context"
	"sync"
	"sync/atomic"
)

// 有向无环图作业
type DAGJob struct {
	// 结束标识，该标识由结束节点写入
	done chan struct{}
	//保证并发时只写入一次
	once *sync.Once
	//有节点出错时终止流程标记
	alreadyDone bool
	//开始节点
	start *Node
	//结束节点
	End *Node
	//所有经过的边，边连接了节点
	edges []*Edge
}

type Edge struct {
	From *Node
	To   *Node
}

type Node struct {
	// 依赖的边
	Dependency []*Edge
	//表示依赖的边有多少个已执行完成，用于判断该节点是否可以执行了
	DepCompleted int32
	//执行任务的函数
	Task RunFn
	//节点的子边
	Children []*Edge
}

type EndNode struct {
	//节点执行完成，往该done写入消息，和workjob中的done共用
	done chan struct{}
	//并发控制，确保只往done中写入一次
	once *sync.Once
}

// 执行函数
type RunFn = func()

func NewNode(task RunFn) *Node {
	return &Node{
		Task: task,
	}
}

// 新建一条边
func NewEdge(from *Node, to *Node) *Edge {
	edge := &Edge{
		From: from,
		To:   to,
	}

	//该条边是from节点的出边
	from.Children = append(from.Children, edge)
	//该条边是to节点的入边
	to.Dependency = append(to.Dependency, edge)

	return edge
}

func (job *DAGJob) AddStartNode(node *Node) {
	job.edges = append(job.edges, NewEdge(job.start, node))
}

func (job *DAGJob) AddEdge(from *Node, to *Node) {
	job.edges = append(job.edges, NewEdge(from, to))
}

func (job *DAGJob) ConnectToEnd(node *Node) {
	job.edges = append(job.edges, NewEdge(node, job.End))
}

func (job *DAGJob) StartWithContext(ctx context.Context) {
	job.start.ExecuteWithContext(ctx, job)
}

func (job *DAGJob) WaitDone() {
	<-job.done
	close(job.done)
}

func (job *DAGJob) interruptDone() {
	job.alreadyDone = true
	job.once.Do(func() { job.done <- struct{}{} })
}

func (n *Node) dependencyHasDone() bool {
	//该节点没有依赖的前置节点，不需要等待，直接返回true
	if n.Dependency == nil {
		return true
	}

	//如果该节点只有一个依赖的前置节点，也直接返回
	if len(n.Dependency) == 1 {
		return true
	}

	//这里将依赖的节点加1，说明有一个依赖的节点完成了
	atomic.AddInt32(&n.DepCompleted, 1)

	//判断当前依赖的节点数量是否和依赖的节点相等，相等，说明都运行完了
	return n.DepCompleted == int32(len(n.Dependency))
}

func (n *Node) ExecuteWithContext(ctx context.Context, job *DAGJob) {
	//所依赖的前置节点没有运行完成，则直接返回
	if !n.dependencyHasDone() {
		return
	}
	//有节点运行出错，终止流程的执行
	if ctx.Err() != nil {
		job.interruptDone()
		return
	}

	//节点具体的运行逻辑
	if n.Task != nil {
		n.Task()
	}

	// 运行子节点
	if len(n.Children) > 0 {
		for idx := 1; idx < len(n.Children); idx++ {
			// 并行执行
			go func(child *Edge) {
				child.To.ExecuteWithContext(ctx, job)
			}(n.Children[idx])
		}

		n.Children[0].To.ExecuteWithContext(ctx, job)
	}

}

func NewDAGJob() *DAGJob {
	job := &DAGJob{
		//开始节点，所有具体的节点都是它的子节点，没有具体的执行逻辑，只为出发其他节点的执行
		start: &Node{Task: nil},
		done:  make(chan struct{}, 1),
		once:  &sync.Once{},
	}

	//加入结束节点
	endNode := &EndNode{
		done: job.done,
		once: job.once,
	}
	job.End = NewNode(func() {
		endNode.once.Do(func() {
			endNode.done <- struct{}{}
		})
	})

	return job
}
